import os
import argparse
import torch
import json
from image_synthesis.utils.misc import instantiate_from_config
from image_synthesis.modeling.modules.clip.simple_tokenizer import SimpleTokenizer
from image_synthesis.data.tsv_dataset import TSVTextDataset
from image_synthesis.utils.io import save_config_to_yaml

def _conceptual_caption_config(data_root, split, image_tsv_file, text_tsv_file):
    """Build the ``instantiate_from_config`` dict for one Conceptual Captions split."""
    return {
        "target": "image_synthesis.data.tsv_dataset.TSVImageTextDataset",
        "params": {
            'data_root': data_root,
            'name': 'conceptualcaption/{}'.format(split),
            "image_tsv_file": image_tsv_file,
            "text_tsv_file": text_tsv_file,
            "text_format": "json",
            "im_preprocessor_config": {
                "target": "image_synthesis.data.utils.image_preprocessor.SimplePreprocessor",
                "params": {
                    "size": 256,
                },
            },
        },
    }


def conceptual_caption(data_root='', save_dir='data/captions', include_train=False):
    """Dump Conceptual Captions text to a file and compute token statistics.

    Every batch's ``data['text']`` entries are written one per line to
    ``<save_dir>/conceptual_caption.txt``; afterwards ``statics_of_captions``
    is run on that file.

    Args:
        data_root: root directory handed to the dataset configs.
        save_dir: directory for the caption file (created if missing).
        include_train: when True, the train split is dumped after the
            validation split. The default (False) dumps validation only.
    """
    datasets = [
        instantiate_from_config(
            _conceptual_caption_config(
                data_root, 'val', ['gcc-val-image.tsv'], ['gcc-val-text.tsv']
            )
        )
    ]
    if include_train:
        datasets.append(
            instantiate_from_config(
                _conceptual_caption_config(
                    data_root,
                    'train',
                    ['gcc-train-image-00.tsv', 'gcc-train-image-01.tsv'],
                    ['gcc-train-text-00.tsv', 'gcc-train-text-01.tsv'],
                )
            )
        )

    os.makedirs(save_dir, exist_ok=True)
    save_file = os.path.join(save_dir, 'conceptual_caption.txt')
    batch_size = 8
    print('Prepare dataset done!')

    # The context manager closes the file even if a worker or the dataset
    # raises part-way through the dump.
    with open(save_file, 'w') as fw:
        for dataset in datasets:
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=16,
                drop_last=False
            )
            # len(dataloader) counts the final partial batch, unlike
            # len(dataset) // batch_size.
            num_batches = len(dataloader)

            for i, data in enumerate(dataloader):
                if i % 100 == 0:
                    print("{}/{}".format(i, num_batches))
                fw.write('\n'.join(data['text']) + '\n')
    print('saved in {}'.format(save_file))

    # get_statics
    statics_of_captions(save_file)

def statics_of_captions(caption_file, subword_end_idx=49152, add_start_and_end=True):
    """Count how many captions have each token length and save the histogram.

    Each line of ``caption_file`` is lower-cased and encoded with
    ``SimpleTokenizer(subword_end_idx)``. The resulting mapping
    ``{token_count: number_of_captions}`` is passed to ``save_config_to_yaml``
    and stored next to the caption file as
    ``<caption_file without extension>_<subword_end_idx>_statics.yaml``.

    Args:
        caption_file: path to a text file with one caption per line.
        subword_end_idx: argument forwarded to ``SimpleTokenizer``.
        add_start_and_end: when True, the ``<|startoftext|>`` and
            ``<|endoftext|>`` tokens are included in each caption's length.
    """
    tokenizer = SimpleTokenizer(subword_end_idx)

    sot_token = [tokenizer.encoder["<|startoftext|>"]] if add_start_and_end else []
    eot_token = [tokenizer.encoder["<|endoftext|>"]] if add_start_and_end else []

    token_statics = {}
    with open(caption_file, 'r') as cf:
        # Stream the file instead of loading every caption into memory.
        for count, caption in enumerate(cf, start=1):
            tokens = sot_token + tokenizer.encode(caption.lower()) + eot_token
            num_tokens = len(tokens)
            token_statics[num_tokens] = token_statics.get(num_tokens, 0) + 1
            if count % 100 == 0:
                # Drop the line terminator so the progress log has no blank lines.
                print(count, caption.rstrip('\n'))

    # save the info
    # Derive the output path from the extension-less name. A plain
    # str.replace('.txt', ...) would return the input path unchanged for a file
    # without '.txt', and the statistics would overwrite the captions.
    root, _ = os.path.splitext(caption_file)
    stats_path = '{}_{}_statics.yaml'.format(root, subword_end_idx)
    save_config_to_yaml(token_statics, stats_path)
    print('statics saved in {}'.format(stats_path))



def get_args(argv=None):
    """Parse the command-line options of the caption extraction script.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``data_root`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Extract dataset captions and compute token length statistics'
    )
    parser.add_argument('--data_root', type=str, default='data',
                        help='dir of datasets')
    parser.add_argument('--save_dir', type=str, default='data/captions',
                        help='dir to save captions')

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: dump the Conceptual Captions text and compute its
    # token length statistics.
    cli_args = get_args()
    conceptual_caption(
        data_root=cli_args.data_root,
        save_dir=cli_args.save_dir,
    )

    # Alternative invocations kept for reference. NOTE(review): tophost_portrait
    # and multi_modal_celeba_hq are not defined in the visible part of this file.
    # statics_of_captions('data/captions/conceptual_caption.txt', subword_end_idx=16384)

    # tophost_portrait(data_root=cli_args.data_root, save_dir=cli_args.save_dir)
    # statics_of_captions('data/captions/tophost_portrait.txt')

    # multi_modal_celeba_hq(data_root=cli_args.data_root, save_dir=cli_args.save_dir)